KarL05/Aiyiyi's Blog
WC2021 括号路径题解

题目难度: 省选

题目分类: 图论

题目大意: [WC2021]括号路径

有一图, \(n\) 点, \(2m\)

定义 \([u,v,k]\)\(u\)\(v\) 边权值为 \(k\), 当 \(k>0\) 时, \([u,v,k]\) 是一个 \(k\) 类型的左括号, 当 \(k<0\) 时, \([u,v,k]\) 是一个 \(k\) 类型的右括号

对于一个路径, 其对应括号序列是每条边依次被映射成对应的括号所产生的序列

若图中存在 \([u,v,k]\), 则存在 \([v,u,-k]\)

定义 \((u,v)\) 为括号路径当存在一条从 \(u\)\(v\) 的路径使得对应的括号序列是一个合法的括号序列

求有多少点对 \((u,v)\) 是一个合法的括号序列

\(n≤m ≤ 10^5\)

解题思路:

经典并查集

题目答案:

1: 若 (a,b), 则 (b,a)

2: 若 (a,b) 且 (b,c) , 则 (a,c)

通过 2 可以得到, 每个联通块对答案的贡献是 \(\frac{n(n+1)}{2}\)

3: 若 \((a,b)\)\([u,a,-k]\)\([v,b,-k]\), 则 \(u\) \(v\) 在新图上有边

使用并查集维护连通性, 给每个联通块锁成一个点, 但对于每种边保留一个类似于触手的特征边来辅助合并, 使用哈希表来做到 \(O(n\log n)\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include"bits/stdc++.h"
using namespace std;

int read () {
char c = getchar();
int x = 0;
while (c<'0'||c>'9') c = getchar();
while ('0'<=c&&c<='9') {
x = x*10+c-'0';
c = getchar();
}
return x;
}

int n,m,k;
const int maxn = 1e6;
map<int,int> E[maxn];
int C[maxn];
int ind[maxn];
int cnt[maxn];

struct Node {
int x;
int y;
};

int f[maxn];
int sz[maxn];

int _find (int x) {
if (f[x]==x) return x;
f[x] = _find(f[x]);
return f[x];
}

queue<Node> q;

void _merge (int a, int b) {
int fa = _find(a);
int fb = _find(b);
if (fa==fb) return;
if (sz[fa]<sz[fb]) swap(fa,fb);
map<int,int>::iterator It;
for (It=E[fb].begin();It!=E[fb].end();It++) {
int c = It->first;
int to = It->second;
if (E[fa][c]) q.push({E[fa][c],to});
else E[fa][c] = to;
}
f[fb] = fa;
sz[fa] += sz[fb];
}

int main () {
n = read();
m = read();
k = read();
for (int i=1;i<=n;i++) {
f[i] = i;
sz[i] = 1;
}
for (int i=1;i<=m;i++) {
int u = read();
int v = read();
int w = read();
if (E[v][w]) q.push({E[v][w],u});
else E[v][w] = u;
}
while (!q.empty()) {
int x = q.front().x;
int y = q.front().y;
q.pop();
_merge(x,y);
}
for (int i=1;i<=n;i++) {
f[i] = _find(i);
cnt[f[i]]++;
}
long long ans = 0;
for (int i=1;i<=n;i++) {
ans += 1ll*cnt[i]*(cnt[i]-1)/2;
}
cout<<ans<<endl;
}